package hex;

import hex.deeplearning.*;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.*;
import water.fvec.*;
import water.util.Log;

import java.util.HashSet;

/**
 * Trains Deep Learning auto-encoders on the ECG "discord" training data and checks that
 * <ul>
 *   <li>the model's self-reported MSE agrees with the reconstruction error obtained by scoring
 *       the training frame (both via {@code scoreAutoEncoder} and via a manual computation), and</li>
 *   <li>exactly rows 20, 21 and 22 of the test data exceed the outlier threshold.</li>
 * </ul>
 */
public class DeepLearningAutoEncoderTest extends TestUtil {
  /*
    Visualize outliers with the following R code (from smalldata/anomaly dir):

    train <- scan("ecg_discord_train.csv", sep=",")
    test  <- scan("ecg_discord_test.csv",  sep=",")
    plot.ts(train)
    plot.ts(test)
  */

  static final String PATH = "smalldata/anomaly/ecg_discord_train.csv"; //first 20 points
  static final String PATH2 = "smalldata/anomaly/ecg_discord_test.csv"; //first 22 points

  /** Runs once before the tests of this class; delegates to {@code stall_till_cloudsize}. */
  @BeforeClass public static void stall() {
    stall_till_cloudsize(JUnitRunnerDebug.NODES);
  }

  /**
   * Builds one auto-encoder per sparsity setting ({@code sparsity_beta} of 0 and 0.1), verifies
   * the reconstruction error on the training frame and the set of outliers found in the test frame.
   * All frames, the job and the model created along the way are deleted again, also on failure.
   */
  @Test
  public void run() {
    long seed = 0xDECAF;

    Frame train=null, test=null;

    try {
      Key file_train = NFSFileVec.make(find_test_file(PATH));
      train = ParseDataset2.parse(Key.make(), new Key[]{file_train});
      Key file_test = NFSFileVec.make(find_test_file(PATH2));
      test = ParseDataset2.parse(Key.make(), new Key[]{file_test});

      for (float sparsity_beta : new float[]{0, 0.1f}) {
        DeepLearning p = new DeepLearning();
        p.source = train;
        p.autoencoder = true;
        p.response = train.lastVec();
        p.classification = false;

        p.seed = seed;
        p.hidden = new int[]{100, 100};
        p.adaptive_rate = true;
        p.train_samples_per_iteration = -1;
        p.sparsity_beta = sparsity_beta;
        p.average_activation = -0.7;
        p.l1 = 1e-4;
//    p.l2 = 1e-4;
//    p.rate = 1e-5;
        p.activation = DeepLearning.Activation.Tanh;
        p.loss = DeepLearning.Loss.MeanSquare;
//    p.initial_weight_distribution = DeepLearning.InitialWeightDistribution.Normal;
//    p.initial_weight_scale = 1e-3;
        p.epochs = 500;
//    p.shuffle_training_data = true;
        p.force_load_balance = false; //if enabled, Hogwild gets ugly on many cores
        p.invoke();

        DeepLearningModel mymodel = UKV.get(p.dest());
        Frame l2_frame_train=null, l2_frame_test=null;

        // Verification of results; the log text is collected here and emitted in the finally block
        StringBuilder sb = new StringBuilder();
        try {
          // Fail with a readable message instead of an NPE further down
          Assert.assertNotNull("no model found under the job's destination key", mymodel);

          sb.append("Verifying results.\n");

          // Training data

          // Reconstruct data using the same helper functions and verify that self-reported MSE agrees
          double quantile = 0.95;
          l2_frame_train = mymodel.scoreAutoEncoder(train);
          final Vec l2_train = l2_frame_train.anyVec();
          sb.append("Mean reconstruction error: " + l2_train.mean() + "\n");
          Assert.assertEquals(mymodel.mse(), l2_train.mean(), 1e-7);
          Assert.assertTrue("too big a reconstruction error: " + l2_train.mean(), l2_train.mean() < 0.06);

          // manually compute L2
          Frame reconstr = mymodel.score(train); //this creates real values in original space
          double mean_l2 = 0;
          try {
            for (int r = 0; r < reconstr.numRows(); ++r) {
              double my_l2 = 0;
              for (int c = 0; c < reconstr.numCols(); ++c) {
                // The difference is scaled by the per-column _normMul factor before squaring
                // (presumably to get back to the normalized space mse() is reported in - TODO confirm)
                my_l2 += Math.pow((reconstr.vec(c).at(r) - train.vec(c).at(r)) * mymodel.model_info().data_info()._normMul[c], 2);
              }
              my_l2 /= reconstr.numCols();
              mean_l2 += my_l2;
            }
            mean_l2 /= reconstr.numRows();
          } finally {
            // release the reconstruction even if the loop above throws
            reconstr.delete();
          }
          sb.append("Mean reconstruction error (train): " + l2_train.mean() + "\n");
          Assert.assertEquals(mymodel.mse(), mean_l2, 1e-7);

          // print stats and potential outliers
          sb.append("The following training points are reconstructed with an error above the " + quantile * 100 + "-th percentile - check for \"goodness\" of training data.\n");
          double thresh_train = mymodel.calcOutlierThreshold(l2_train, quantile);
          for (long i = 0; i < l2_train.length(); i++) {
            if (l2_train.at(i) > thresh_train) {
              sb.append(String.format("row %d : l2_train error = %5f\n", i, l2_train.at(i)));
            }
          }

          // Test data

          // Score the test data; rows whose error exceeds mult times the training threshold are outliers
          l2_frame_test = mymodel.scoreAutoEncoder(test);
          final Vec l2_test = l2_frame_test.anyVec();
          double mult = 10;
          double thresh_test = mult * thresh_train;
          sb.append("\nFinding outliers.\n");
          sb.append("Mean reconstruction error (test): " + l2_test.mean() + "\n");

          // print stats and potential outliers
          sb.append("The following test points are reconstructed with an error greater than " + mult + " times the " + quantile * 100 + "-th percentile reconstruction error of the training data:\n");
          HashSet<Long> outliers = new HashSet<Long>();
          for (long i = 0; i < l2_test.length(); i++) {
            if (l2_test.at(i) > thresh_test) {
              outliers.add(i);
              sb.append(String.format("row %d : l2 error = %5f\n", i, l2_test.at(i)));
            }
          }

          // check that the all outliers are found (and nothing else)
          Assert.assertTrue("row 20 not flagged as outlier: " + outliers, outliers.contains(20L));
          Assert.assertTrue("row 21 not flagged as outlier: " + outliers, outliers.contains(21L));
          Assert.assertTrue("row 22 not flagged as outlier: " + outliers, outliers.contains(22L));
          Assert.assertEquals("unexpected number of outliers: " + outliers, 3, outliers.size());
        } finally {
          Log.info(sb);
          // cleanup
          p.delete();
          if (mymodel!=null) mymodel.delete();
          if (l2_frame_train!=null) l2_frame_train.delete();
          if (l2_frame_test!=null) l2_frame_test.delete();
        }
      }
    } finally {
      if (train!=null) train.delete();
      if (test!=null) test.delete();
    }
  }
}

